# This is a from-scratch implementation of NU-MBROLA using numpy.
# NU-MBROLA was originally developed by Baris Bozkurt, Michel Bagein and Thierry Dutoit:
# http://tcts.fpms.ac.be/publications/papers/2001/ssw4_bbmbtd.pdf
# https://github.com/numediart/MBROLA
# .pho line format: wav start end length [pitch...]
# see also https://github.com/yuanchao/tn_fnds_yc 
# and https://github.com/tuanad121/Python-WORLD

import os
import numpy as np
import soundfile as sf

# Sample rate (Hz) that every voicebank .wav must have; also the rate of the written output.
default_samplerate = 44100
# NOTE(review): not referenced anywhere in the visible code; presumably intended as a
# fallback period for unvoiced segments — confirm its unit/purpose or remove it.
unvoiced_period = 50

import synth_common

class Phoneme:
	"""One line of a .pho score: a slice of a voicebank recording.

	Attributes set here:
	  wav    -- key of the voicebank recording (its .wav file name)
	  start  -- 2nd .pho field as float (milliseconds in the recording)
	  end    -- 3rd .pho field as float (milliseconds in the recording)
	  length -- 4th .pho field as float (duration on the output timeline, ms)
	  pitch  -- flat list of floats filled by addpitch()
	"""
	def __init__(self, wav, start, end, length):
		self.wav = wav
		self.start = float(start)
		self.end = float(end)
		self.length = float(length)
		self.pitch = list()

	def addpitch(self, p):
		"""Convert *p* to float and append it to the pitch list."""
		self.pitch.append(float(p))

def readpho(fn):
	"""Parse a .pho score file into a list of Phoneme objects.

	Each non-blank line has the form ``wav start end length [pitch...]``
	with whitespace-separated fields; blank or whitespace-only lines are
	skipped. Every returned Phoneme additionally gets ``left`` / ``right``:
	its start and end on the output timeline, obtained by accumulating the
	``length`` fields in file order.

	Raises IndexError if a non-blank line has fewer than four fields.
	"""
	ret = []
	current_pos = 0
	# `with` guarantees the file is closed even if a line fails to parse
	with open(fn) as f:
		for line in f:
			# split() (not split(" ")) tolerates tabs, repeated and trailing spaces
			cmd = line.split()
			if not cmd:
				continue
			pho = Phoneme(cmd[0], cmd[1], cmd[2], cmd[3])
			pho.left = current_pos
			current_pos += pho.length
			pho.right = current_pos
			for value in cmd[4:]:
				pho.addpitch(value)
			ret.append(pho)
	return ret
	
class VoiceRecord:
	"""One voicebank recording bundled with its analysis tracks."""
	def __init__(self, wav, f0, pmk):
		# wav: audio samples; f0 / pmk: track data loaded from the
		# matching .f0 / .pmk files (format is defined by synth_common)
		self.wav, self.f0, self.pmk = wav, f0, pmk

class Synth():
	"""NU-MBROLA-style pitch-synchronous overlap-add synthesizer.

	Loads every ``.wav`` (plus its ``.f0`` / ``.pmk`` tracks) from a voicebank
	directory, parses a ``.pho`` score, and provides getPeriod / getImpulse,
	which a driver loop uses to fill ``ola_buffer``.

	Attributes:
	  vb     -- dict: wav file name -> VoiceRecord, or None on sample-rate mismatch
	  pho    -- list of Phoneme returned by readpho()
	  valid  -- result of validateInput(); the attributes below exist only if True
	  song_length         -- output duration in ms (right edge of the last phoneme)
	  song_length_samples -- output duration in samples
	  ola_buffer          -- float output buffer with 1024 samples of tail padding
	  pitchpoints         -- list of (ms, Hz) breakpoints, see calculatePitchPoints()
	"""
	def __init__(self,vbdir,phofile):
		self.vb = self.readvoicebank(vbdir)
		self.pho = readpho(phofile)
		self.valid = self.validateInput()
		if(self.valid):
			self.song_length = self.pho[-1].right
			self.song_length_samples = int(self.song_length*default_samplerate/1000)
			# padding so impulses placed near the end are not truncated
			self.ola_buffer = np.zeros(self.song_length_samples+1024)
			self.pitchpoints = self.calculatePitchPoints()
		else:
			print("input is invalid")

	def readvoicebank(self,path):
		"""Load every ``*.wav`` file in *path* with its ``.f0`` and ``.pmk`` tracks.

		Returns a dict keyed by the wav file name (e.g. ``"a.wav"``), or None
		(after printing an error) as soon as a wav's sample rate differs from
		default_samplerate.
		"""
		ret = {}
		for f in os.listdir(path):
			filepath = os.path.join(path, f)
			if not os.path.isfile(filepath):
				continue
			base, ext = os.path.splitext(filepath)
			if ext != ".wav":
				continue
			x, fs = sf.read(filepath)
			if(fs!=default_samplerate):
				print("ERROR: samplerate mismatch")
				return None
			# only the second element returned by readTrack is kept
			_, d0 = synth_common.readTrack(base+".f0")
			_, d1 = synth_common.readTrack(base+".pmk")
			ret[f] = VoiceRecord(x,d0,d1)
		return ret

	def validateInput(self):
		"""Check that the score can be rendered with the loaded voicebank.

		Returns False (printing a reason where one is not already printed)
		if the voicebank failed to load, the score has no phonemes, a
		phoneme references a wav that is not in the voicebank, or its
		start/end (ms) lie beyond the end of that wav. Returns True otherwise.
		"""
		if self.vb is None:
			# readvoicebank() already printed why loading failed
			return False
		if not self.pho:
			# without this check __init__ would fail on self.pho[-1]
			print("ERROR: no phonemes")
			return False
		for p in self.pho:
			if p.wav not in self.vb:
				print("ERROR: not_found "+p.wav)
				return False
			vr = self.vb[p.wav]
			start = int(p.start*default_samplerate/1000)
			end = int(p.end*default_samplerate/1000)
			if(start >= len(vr.wav) or end >= len(vr.wav)):
				print("ERROR: start or end out of range")
				return False
		return True

	def calculatePitchPoints(self):
		"""Build the global pitch contour as a list of (time_ms, hz) breakpoints.

		Each phoneme's pitch list is read as (percent, hz) pairs; percent is
		the position inside the phoneme's [left, right] output span. A point
		at time 0 repeats the first hz, and a final point at song_length
		repeats the last hz, so the contour spans the whole song. A trailing
		unpaired value is ignored with a warning.
		"""
		ret = []
		hz = 0
		for p in self.pho:
			if len(p.pitch) % 2:
				print("WARNING: odd number of pitch values for "+p.wav+", ignoring the last one")
			for percent, hz in zip(p.pitch[0::2], p.pitch[1::2]):
				pos = percent*0.01*(p.right-p.left)+p.left
				if not ret:
					ret.append((0,hz))
				ret.append((pos,hz))
		ret.append((self.song_length,hz))
		return ret

	def getPeriod(self,index):
		"""Return the synthesis pitch period, in samples, at output sample *index*.

		f0 is linearly interpolated between the two pitch points whose times
		bracket *index* (endpoints inclusive). Falls back to 100 samples, with
		a message, when no bracket exists or the interpolated f0 is not
		positive (which would otherwise divide by zero or step backwards).
		"""
		ms = index*1.0/default_samplerate*1000
		for p0, p1 in zip(self.pitchpoints, self.pitchpoints[1:]):
			if p0[0] <= ms <= p1[0]:
				span = p1[0]-p0[0]
				if span > 0:
					# t = 0 at p0 and t = 1 at p1
					t = (ms-p0[0])/span
					f0 = (1-t)*p0[1] + t*p1[1]
				else:
					f0 = p0[1]
				if f0 > 0:
					return default_samplerate/f0
				print("Synthesizer::getPeriod >> non-positive f0, using default value")
				return 100
		print("Synthesizer::getPeriod >> Not Found, using default value")
		return 100

	def getImpulse(self,index):
		"""Extract the analysis impulse to overlap-add at output sample *index*.

		Finds the first phoneme whose output span [left, right] contains the
		time of *index*. Maps that time linearly onto [start, end] inside the
		phoneme's source wav, the same absolute interpretation validateInput
		uses. Then asks synth_common for the pitch and the impulse there.

		Returns (impulse, target_indices): the impulse samples that land inside
		ola_buffer when the impulse is centred on *index*, and their (unique,
		integer) buffer indices. Returns (None, None) when no phoneme covers
		*index*.
		"""
		ms = index*1.0/default_samplerate*1000
		for p in self.pho:
			if p.right <= p.left:
				continue  # zero-length phoneme: nothing to map (and avoids /0)
			if not (p.left <= ms <= p.right):
				continue
			vr = self.vb[p.wav]
			frac = (ms-p.left)/(p.right-p.left)
			pos_ms = p.start+(p.end-p.start)*frac
			f0 = synth_common.getPitch(vr.f0,vr.pmk,[pos_ms*0.001,1])
			imp = np.asarray(synth_common.getImpulse(vr.wav,default_samplerate,f0,vr.pmk,pos_ms*0.001))

			# centre the impulse on index; drop samples that fall outside the
			# buffer instead of clamping them onto the edge indices
			target = np.arange(len(imp)) - len(imp)//2 + int(index)
			inside = (target >= 0) & (target < len(self.ola_buffer))
			return imp[inside], target[inside]
		return None,None
		

# Driver: render the score into the overlap-add buffer and write it to disk.
# Usage: python <this file> [voicebank_dir pho_file [out_wav]]
# With no arguments (or when imported) the original hard-coded inputs are used.
import sys
_vbdir = "/home/isengaara/Hacking/Audio/QTAU/qtau/get_utau/TETO_110401"
_phofile = "numbrola.pho"
_outfile = 'numbrola.wav'
if __name__ == "__main__" and len(sys.argv) >= 3:
	_vbdir, _phofile = sys.argv[1], sys.argv[2]
	if len(sys.argv) >= 4:
		_outfile = sys.argv[3]

s = Synth(_vbdir, _phofile)

if(s.valid):
	i = 0
	while(i<s.song_length_samples):
		# one impulse per synthesis pitch period
		period = s.getPeriod(i)
		impulse,index = s.getImpulse(i)
		if impulse is not None:
			# np.add.at accumulates every entry even when index repeats;
			# buffer[index] += impulse would apply a repeated index only once
			np.add.at(s.ola_buffer, index, impulse)
		i+=period

	sf.write(_outfile, s.ola_buffer, default_samplerate)
